{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config.json\t\t\t       model.ckpt-300000.index\r\n",
      "model.ckpt-300000.data-00000-of-00001  model.ckpt-300000.meta\r\n"
     ]
    }
   ],
   "source": [
    "# !wget https://f000.backblazeb2.com/file/malaya-model/bert-bahasa/alxlnet-base-2020-04-10.tar.gz\n",
    "# !tar -zxf alxlnet-base-2020-04-10.tar.gz\n",
    "!ls alxlnet-base-2020-04-10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/alxlnet/config/alxlnet-base_config.json\n",
    "# !wget https://raw.githubusercontent.com/huseinzol05/Malaya/master/pretrained-model/preprocess/sp10m.cased.v9.model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/xlnet/model_utils.py:295: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "from xlnet import model_utils\n",
    "from xlnet.prepro_utils import preprocess_text, encode_ids\n",
    "import pickle\n",
    "import json\n",
    "import alxlnet as xlnet\n",
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentencepiece as spm\n",
    "\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load('sp10m.cased.v9.model')\n",
    "\n",
    "def tokenize_fn(text):\n",
    "    text = preprocess_text(text, lower= False)\n",
    "    return encode_ids(sp_model, text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix = 'https://f000.backblazeb2.com/file/malay-dataset/tagging/ontonotes5/'\n",
    "\n",
    "urls = \"\"\"\n",
    "augmentation-address-ontonotes5.json\n",
    "augmentation-event-ontonotes5.json\n",
    "augmentation-fac-ontonotes5.json\n",
    "augmentation-gpe-ontonotes5.json\n",
    "augmentation-language-ontonotes5.json\n",
    "augmentation-law-ontonotes5.json\n",
    "augmentation-loc-ontonotes5.json\n",
    "augmentation-norp-ontonotes5.json\n",
    "augmentation-org-ontonotes5.json\n",
    "augmentation-person-ontonotes5.json\n",
    "augmentation-product-ontonotes5.json\n",
    "augmentation-work-of-art-ontonotes5.json\n",
    "ontonotes5-train-test.json\n",
    "\"\"\"\n",
    "urls = list(filter(None, urls.split('\\n')))\n",
    "\n",
    "import os\n",
    "\n",
    "# uncomment to download\n",
    "# for url in urls:\n",
    "#     print(url)\n",
    "#     os.system(f'wget {prefix}{url}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['train_X', 'train_Y', 'test_X', 'test_Y'])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "with open('ontonotes5-train-test.json') as fopen:\n",
    "    data = json.load(fopen)\n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = data['train_X']\n",
    "test_X = data['test_X']\n",
    "train_Y = data['train_Y']\n",
    "test_Y = data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = [\n",
    "    {'Tag': 'OTHER', 'Description': 'other'},\n",
    "    {'Tag': 'ADDRESS', 'Description': 'Address of physical location.'},\n",
    "    {'Tag': 'PERSON', 'Description': 'People, including fictional.'},\n",
    "    {\n",
    "        'Tag': 'NORP',\n",
    "        'Description': 'Nationalities or religious or political groups.',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'FAC',\n",
    "        'Description': 'Buildings, airports, highways, bridges, etc.',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'ORG',\n",
    "        'Description': 'Companies, agencies, institutions, etc.',\n",
    "    },\n",
    "    {'Tag': 'GPE', 'Description': 'Countries, cities, states.'},\n",
    "    {\n",
    "        'Tag': 'LOC',\n",
    "        'Description': 'Non-GPE locations, mountain ranges, bodies of water.',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'PRODUCT',\n",
    "        'Description': 'Objects, vehicles, foods, etc. (Not services.)',\n",
    "    },\n",
    "    {\n",
    "        'Tag': 'EVENT',\n",
    "        'Description': 'Named hurricanes, battles, wars, sports events, etc.',\n",
    "    },\n",
    "    {'Tag': 'WORK_OF_ART', 'Description': 'Titles of books, songs, etc.'},\n",
    "    {'Tag': 'LAW', 'Description': 'Named documents made into laws.'},\n",
    "    {'Tag': 'LANGUAGE', 'Description': 'Any named language.'},\n",
    "    {\n",
    "        'Tag': 'DATE',\n",
    "        'Description': 'Absolute or relative dates or periods.',\n",
    "    },\n",
    "    {'Tag': 'TIME', 'Description': 'Times smaller than a day.'},\n",
    "    {'Tag': 'PERCENT', 'Description': 'Percentage, including \"%\".'},\n",
    "    {'Tag': 'MONEY', 'Description': 'Monetary values, including unit.'},\n",
    "    {\n",
    "        'Tag': 'QUANTITY',\n",
    "        'Description': 'Measurements, as of weight or distance.',\n",
    "    },\n",
    "    {'Tag': 'ORDINAL', 'Description': '\"first\", \"second\", etc.'},\n",
    "    {\n",
    "        'Tag': 'CARDINAL',\n",
    "        'Description': 'Numerals that do not fall under another type.',\n",
    "    },\n",
    "]\n",
    "d = [d['Tag'] for d in d]\n",
    "d = ['PAD', 'X'] + d\n",
    "tag2idx = {i: no for no, i in enumerate(d)}\n",
    "idx2tag = {no: i for no, i in enumerate(d)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "augmentation-org-ontonotes5.json\n",
      "35150 35150\n",
      "6300 6300\n",
      "augmentation-fac-ontonotes5.json\n",
      "17040 17040\n",
      "4260 4260\n",
      "augmentation-loc-ontonotes5.json\n",
      "13040 13040\n",
      "2460 2460\n",
      "augmentation-gpe-ontonotes5.json\n",
      "21060 21060\n",
      "0 0\n",
      "augmentation-work-of-art-ontonotes5.json\n",
      "4020 4020\n",
      "1020 1020\n",
      "augmentation-event-ontonotes5.json\n",
      "1665 1665\n",
      "0 0\n",
      "augmentation-person-ontonotes5.json\n",
      "37883 37883\n",
      "2530 2530\n",
      "augmentation-product-ontonotes5.json\n",
      "7040 7040\n",
      "1760 1760\n",
      "augmentation-law-ontonotes5.json\n",
      "12584 12584\n",
      "3161 3161\n",
      "augmentation-language-ontonotes5.json\n",
      "6860 6860\n",
      "0 0\n",
      "augmentation-norp-ontonotes5.json\n",
      "30660 30660\n",
      "5940 5940\n",
      "augmentation-address-ontonotes5.json\n",
      "57502 57502\n",
      "15106 15106\n"
     ]
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "augmented = glob('augmentation-*-ontonotes5.json')\n",
    "\n",
    "for f in augmented:\n",
    "    print(f)\n",
    "    with open(f) as fopen:\n",
    "        data = json.load(fopen)\n",
    "        \n",
    "    print(len(data.get('train_X', [])), len(data.get('train_Y', [])))\n",
    "    print(len(data.get('test_X', [])), len(data.get('test_Y', [])))\n",
    "    \n",
    "    train_X.extend(data.get('train_X', []))\n",
    "    train_Y.extend(data.get('train_Y', []))\n",
    "    test_X.extend(data.get('test_X', []))\n",
    "    test_Y.extend(data.get('test_Y', []))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "SEG_ID_A   = 0\n",
    "SEG_ID_B   = 1\n",
    "SEG_ID_CLS = 2\n",
    "SEG_ID_SEP = 3\n",
    "SEG_ID_PAD = 4\n",
    "\n",
    "special_symbols = {\n",
    "    \"<unk>\"  : 0,\n",
    "    \"<s>\"    : 1,\n",
    "    \"</s>\"   : 2,\n",
    "    \"<cls>\"  : 3,\n",
    "    \"<sep>\"  : 4,\n",
    "    \"<pad>\"  : 5,\n",
    "    \"<mask>\" : 6,\n",
    "    \"<eod>\"  : 7,\n",
    "    \"<eop>\"  : 8,\n",
    "}\n",
    "\n",
    "VOCAB_SIZE = 32000\n",
    "UNK_ID = special_symbols[\"<unk>\"]\n",
    "CLS_ID = special_symbols[\"<cls>\"]\n",
    "SEP_ID = special_symbols[\"<sep>\"]\n",
    "MASK_ID = special_symbols[\"<mask>\"]\n",
    "EOD_ID = special_symbols[\"<eod>\"]\n",
    "\n",
    "def XY(strings):\n",
    "    left_train, right_train = strings[0], strings[1]\n",
    "    X, Y, segments, masks = [], [], [], []\n",
    "    for i in tqdm(range(len(left_train))):\n",
    "        left = left_train[i]\n",
    "        right = right_train[i]\n",
    "        bert_tokens = []\n",
    "        y = []\n",
    "        for no, orig_token in enumerate(left):\n",
    "            t = tokenize_fn(orig_token)\n",
    "            bert_tokens.extend(t)\n",
    "            if len(t):\n",
    "                y.append(right[no])\n",
    "            y.extend(['X'] * (len(t) - 1))\n",
    "        bert_tokens.extend([4, 3])\n",
    "        segment = [0] * (len(bert_tokens) - 1) + [SEG_ID_CLS]\n",
    "        input_mask = [0] * len(segment)\n",
    "        y.extend(['PAD', 'PAD'])\n",
    "        y = [tag2idx[i] for i in y]\n",
    "        if len(bert_tokens) != len(y):\n",
    "            print(i)\n",
    "        X.append(bert_tokens)\n",
    "        Y.append(y)\n",
    "        segments.append(segment)\n",
    "        masks.append(input_mask)\n",
    "    return [(X, Y, segments, masks)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40812/40812 [00:37<00:00, 1078.20it/s]\n",
      "100%|██████████| 1/1 [00:00<00:00, 734.81it/s].64it/s]]\n",
      "100%|██████████| 40812/40812 [00:37<00:00, 1082.87it/s]\n",
      "100%|██████████| 40812/40812 [00:36<00:00, 1118.36it/s]\n",
      " 84%|████████▎ | 34172/40812 [00:31<00:06, 1088.15it/s]\n",
      "100%|██████████| 40812/40812 [00:37<00:00, 1077.53it/s]\n",
      "100%|██████████| 40812/40812 [00:37<00:00, 1080.94it/s]\n",
      "100%|██████████| 40812/40812 [00:37<00:00, 1096.70it/s]\n",
      " 75%|███████▌  | 30654/40812 [00:30<00:09, 1066.82it/s]\n",
      "100%|██████████| 40812/40812 [00:38<00:00, 1062.10it/s]\n",
      "100%|██████████| 40812/40812 [00:37<00:00, 1094.03it/s]\n",
      "100%|██████████| 40812/40812 [00:38<00:00, 1058.92it/s]\n",
      "100%|██████████| 40812/40812 [00:39<00:00, 1039.90it/s]\n",
      "100%|██████████| 40812/40812 [00:39<00:00, 1043.74it/s]\n",
      "100%|██████████| 40812/40812 [00:38<00:00, 1071.58it/s]\n",
      "100%|██████████| 40812/40812 [00:41<00:00, 979.16it/s] \n",
      "100%|██████████| 40812/40812 [00:40<00:00, 996.73it/s] \n"
     ]
    }
   ],
   "source": [
    "import cleaning\n",
    "t = cleaning.multiprocessing(train_X, train_Y, XY)\n",
    "train_X, train_Y, train_segments, train_masks = [], [], [], []\n",
    "for x, y, s, m in t:\n",
    "    train_X.extend(x)\n",
    "    train_Y.extend(y)\n",
    "    train_segments.extend(s)\n",
    "    train_masks.extend(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5494/5494 [00:09<00:00, 552.30it/s] \n",
      "100%|██████████| 10/10 [00:00<00:00, 1214.23it/s]/s]]\n",
      "100%|██████████| 5494/5494 [00:10<00:00, 528.50it/s] \n",
      " 67%|██████▋   | 3679/5494 [00:08<00:01, 1135.57it/s]\n",
      "100%|██████████| 5494/5494 [00:10<00:00, 527.77it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 544.30it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 524.27it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 515.78it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 534.70it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 521.69it/s] \n",
      " 97%|█████████▋| 5335/5494 [00:10<00:00, 1172.21it/s]\n",
      "100%|██████████| 5494/5494 [00:10<00:00, 531.23it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 531.53it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 526.53it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 527.54it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 509.61it/s] \n",
      "100%|██████████| 5494/5494 [00:10<00:00, 519.75it/s] \n"
     ]
    }
   ],
   "source": [
    "t = cleaning.multiprocessing(test_X, test_Y, XY)\n",
    "test_X, test_Y, test_segments, test_masks = [], [], [], []\n",
    "for x, y, s, m in t:\n",
    "    test_X.extend(x)\n",
    "    test_Y.extend(y)\n",
    "    test_segments.extend(s)\n",
    "    test_masks.extend(m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.utils import shuffle\n",
    "\n",
    "train_X, train_Y, train_masks, train_segments = shuffle(train_X, train_Y, train_masks, train_segments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet.py:70: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "kwargs = dict(\n",
    "      is_training=True,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.1,\n",
    "      dropatt=0.1,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clamp_len=-1)\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='alxlnet-base-2020-04-10/config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "102030 10203\n"
     ]
    }
   ],
   "source": [
    "epoch = 5\n",
    "batch_size = 32\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(train_X) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)\n",
    "print(num_train_steps, num_warmup_steps)\n",
    "\n",
    "training_parameters = dict(\n",
    "      decay_method = 'poly',\n",
    "      train_steps = num_train_steps,\n",
    "      learning_rate = 2e-5,\n",
    "      warmup_steps = num_warmup_steps,\n",
    "      min_lr_ratio = 0.0,\n",
    "      weight_decay = 0.00,\n",
    "      adam_epsilon = 1e-8,\n",
    "      num_core_per_host = 1,\n",
    "      lr_layer_decay_rate = 1,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.0,\n",
    "      dropatt=0.0,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.02,\n",
    "      clip = 1.0,\n",
    "      clamp_len=-1,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Parameter:\n",
    "    def __init__(self, decay_method, warmup_steps, weight_decay, adam_epsilon, \n",
    "                num_core_per_host, lr_layer_decay_rate, use_tpu, learning_rate, train_steps,\n",
    "                min_lr_ratio, clip, **kwargs):\n",
    "        self.decay_method = decay_method\n",
    "        self.warmup_steps = warmup_steps\n",
    "        self.weight_decay = weight_decay\n",
    "        self.adam_epsilon = adam_epsilon\n",
    "        self.num_core_per_host = num_core_per_host\n",
    "        self.lr_layer_decay_rate = lr_layer_decay_rate\n",
    "        self.use_tpu = use_tpu\n",
    "        self.learning_rate = learning_rate\n",
    "        self.train_steps = train_steps\n",
    "        self.min_lr_ratio = min_lr_ratio\n",
    "        self.clip = clip\n",
    "        \n",
    "training_parameters = Parameter(**training_parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.lengths = tf.count_nonzero(self.X, 1)\n",
    "        self.maxlen = tf.shape(self.X)[1]\n",
    "        \n",
    "        xlnet_model = xlnet.XLNetModel(\n",
    "            xlnet_config=xlnet_config,\n",
    "            run_config=xlnet_parameters,\n",
    "            input_ids=tf.transpose(self.X, [1, 0]),\n",
    "            seg_ids=tf.transpose(self.segment_ids, [1, 0]),\n",
    "            input_mask=tf.transpose(self.input_masks, [1, 0]))\n",
    "        output_layer = xlnet_model.get_sequence_output()\n",
    "        output_layer = tf.transpose(output_layer, [1, 0, 2])\n",
    "        \n",
    "        logits = tf.layers.dense(output_layer, dimension_output)\n",
    "        y_t = self.Y\n",
    "        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(\n",
    "            logits, y_t, self.lengths\n",
    "        )\n",
    "        self.cost = tf.reduce_mean(-log_likelihood)\n",
    "        self.optimizer = tf.train.AdamOptimizer(\n",
    "            learning_rate = learning_rate\n",
    "        ).minimize(self.cost)\n",
    "        mask = tf.sequence_mask(self.lengths, maxlen = self.maxlen)\n",
    "        self.tags_seq, tags_score = tf.contrib.crf.crf_decode(\n",
    "            logits, transition_params, self.lengths\n",
    "        )\n",
    "        self.tags_seq = tf.identity(self.tags_seq, name = 'logits')\n",
    "\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(self.tags_seq, mask)\n",
    "        mask_label = tf.boolean_mask(y_t, mask)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet.py:253: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet.py:253: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet_modeling.py:696: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet_modeling.py:703: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet_modeling.py:808: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/malaya/Malaya/session/entities/alxlnet_modeling.py:109: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/contrib/crf/python/ops/crf.py:99: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/contrib/crf/python/ops/crf.py:213: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n"
     ]
    }
   ],
   "source": [
    "dimension_output = len(tag2idx)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Compute the union of the current variables and checkpoint variables.\"\"\"\n",
    "    assignment_map = {}\n",
    "    initialized_variable_names = {}\n",
    "\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        name = var.name\n",
    "        m = re.match('^(.*):\\\\d+$', name)\n",
    "        if m is not None:\n",
    "            name = m.group(1)\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    init_vars = tf.train.list_variables(init_checkpoint)\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    for x in init_vars:\n",
    "        (name, var) = (x[0], x[1])\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return (assignment_map, initialized_variable_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "tvars = tf.trainable_variables()\n",
    "checkpoint = 'alxlnet-base-2020-04-10/model.ckpt-300000'\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, \n",
    "                                                                                checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-2020-04-10/model.ckpt-300000\n"
     ]
    }
   ],
   "source": [
    "saver = tf.train.Saver(var_list = assignment_map)\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_sentencepiece_tokens_tagging(x, y):\n",
    "    new_paired_tokens = []\n",
    "    n_tokens = len(x)\n",
    "    rejected = ['<cls>', '<sep>']\n",
    "\n",
    "    i = 0\n",
    "\n",
    "    while i < n_tokens:\n",
    "\n",
    "        current_token, current_label = x[i], y[i]\n",
    "        if not current_token.startswith('▁') and current_token not in rejected:\n",
    "            previous_token, previous_label = new_paired_tokens.pop()\n",
    "            merged_token = previous_token\n",
    "            merged_label = [previous_label]\n",
    "            while (\n",
    "                not current_token.startswith('▁')\n",
    "                and current_token not in rejected\n",
    "            ):\n",
    "                merged_token = merged_token + current_token.replace('▁', '')\n",
    "                merged_label.append(current_label)\n",
    "                i = i + 1\n",
    "                current_token, current_label = x[i], y[i]\n",
    "            merged_label = merged_label[0]\n",
    "            new_paired_tokens.append((merged_token, merged_label))\n",
    "\n",
    "        else:\n",
    "            new_paired_tokens.append((current_token, current_label))\n",
    "            i = i + 1\n",
    "\n",
    "    words = [\n",
    "        i[0].replace('▁', '')\n",
    "        for i in new_paired_tokens\n",
    "        if i[0] not in ['<cls>', '<sep>']\n",
    "    ]\n",
    "    labels = [i[1] for i in new_paired_tokens if i[0] not in ['<cls>', '<sep>']]\n",
    "    return words, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = 'KUALA LUMPUR: Sempena sambutan Aidilfitri minggu depan, Perdana Menteri Tun Dr Mahathir Mohamad dan Menteri Pengangkutan Anthony Loke Siew Fook menitipkan pesanan khas kepada orang ramai yang mahu pulang ke kampung halaman masing-masing. Dalam video pendek terbitan Jabatan Keselamatan Jalan Raya (JKJR) itu, Dr Mahathir menasihati mereka supaya berhenti berehat dan tidur sebentar  sekiranya mengantuk ketika memandu.'\n",
    "\n",
    "import re\n",
    "\n",
    "def entities_textcleaning(string, lowering = False):\n",
    "    \"\"\"\n",
    "    use by entities recognition, pos recognition and dependency parsing\n",
    "    \"\"\"\n",
    "    string = re.sub('[^A-Za-z0-9\\-\\/() ]+', ' ', string)\n",
    "    string = re.sub(r'[ ]+', ' ', string).strip()\n",
    "    original_string = string.split()\n",
    "    if lowering:\n",
    "        string = string.lower()\n",
    "    string = [\n",
    "        (original_string[no], word.title() if word.isupper() else word)\n",
    "        for no, word in enumerate(string.split())\n",
    "        if len(word)\n",
    "    ]\n",
    "    return [s[0] for s in string], [s[1] for s in string]\n",
    "\n",
    "def parse_X(left):\n",
    "    bert_tokens = []\n",
    "    for no, orig_token in enumerate(left):\n",
    "        t = tokenize_fn(orig_token)\n",
    "        bert_tokens.extend(t)\n",
    "    bert_tokens.extend([4, 3])\n",
    "    segment = [0] * (len(bert_tokens) - 1) + [SEG_ID_CLS]\n",
    "    input_mask = [0] * len(segment)\n",
    "    s_tokens = [sp_model.IdToPiece(i) for i in bert_tokens]\n",
    "    return bert_tokens, segment, input_mask, s_tokens\n",
    "\n",
    "sequence = entities_textcleaning(string)[1]\n",
    "parsed_sequence, segment_sequence, mask_sequence, xlnet_sequence = parse_X(sequence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Kuala', 'ADDRESS'),\n",
       " ('Lumpur', 'ORG'),\n",
       " ('Sempena', 'LAW'),\n",
       " ('sambutan', 'ADDRESS'),\n",
       " ('Aidilfitri', 'ADDRESS'),\n",
       " ('minggu', 'EVENT'),\n",
       " ('depan', 'NORP'),\n",
       " ('Perdana', 'NORP'),\n",
       " ('Menteri', 'OTHER'),\n",
       " ('Tun', 'FAC'),\n",
       " ('Dr', 'NORP'),\n",
       " ('Mahathir', 'OTHER'),\n",
       " ('Mohamad', 'PRODUCT'),\n",
       " ('dan', 'FAC'),\n",
       " ('Menteri', 'FAC'),\n",
       " ('Pengangkutan', 'ADDRESS'),\n",
       " ('Anthony', 'EVENT'),\n",
       " ('Loke', 'EVENT'),\n",
       " ('Siew', 'LAW'),\n",
       " ('Fook', 'EVENT'),\n",
       " ('menitipkan', 'EVENT'),\n",
       " ('pesanan', 'EVENT'),\n",
       " ('khas', 'FAC'),\n",
       " ('kepada', 'X'),\n",
       " ('orang', 'PRODUCT'),\n",
       " ('ramai', 'PRODUCT'),\n",
       " ('yang', 'PRODUCT'),\n",
       " ('mahu', 'OTHER'),\n",
       " ('pulang', 'FAC'),\n",
       " ('ke', 'ADDRESS'),\n",
       " ('kampung', 'NORP'),\n",
       " ('halaman', 'TIME'),\n",
       " ('masing-masing', 'EVENT'),\n",
       " ('Dalam', 'LOC'),\n",
       " ('video', 'PAD'),\n",
       " ('pendek', 'PAD'),\n",
       " ('terbitan', 'EVENT'),\n",
       " ('Jabatan', 'EVENT'),\n",
       " ('Keselamatan', 'ORDINAL'),\n",
       " ('Jalan', 'EVENT'),\n",
       " ('Raya', 'EVENT'),\n",
       " ('(Jkjr)', 'WORK_OF_ART'),\n",
       " ('itu', 'WORK_OF_ART'),\n",
       " ('Dr', 'MONEY'),\n",
       " ('Mahathir', 'LANGUAGE'),\n",
       " ('menasihati', 'X'),\n",
       " ('mereka', 'ADDRESS'),\n",
       " ('supaya', 'CARDINAL'),\n",
       " ('berhenti', 'ADDRESS'),\n",
       " ('berehat', 'ADDRESS'),\n",
       " ('dan', 'FAC'),\n",
       " ('tidur', 'PERSON'),\n",
       " ('sebentar', 'PRODUCT'),\n",
       " ('sekiranya', 'TIME'),\n",
       " ('mengantuk', 'PERSON'),\n",
       " ('ketika', 'PAD'),\n",
       " ('memandu', 'PAD')]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted = sess.run(model.tags_seq,\n",
    "                feed_dict = {\n",
    "                    model.X: [parsed_sequence],\n",
    "                    model.segment_ids: [segment_sequence],\n",
    "                    model.input_masks: [mask_sequence],\n",
    "                },\n",
    "        )[0]\n",
    "merged = merge_sentencepiece_tokens_tagging(xlnet_sequence, [idx2tag[d] for d in predicted])\n",
    "list(zip(merged[0], merged[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 20407/20407 [2:27:38<00:00,  2.30it/s, accuracy=1, cost=0.114]     \n",
      "test minibatch loop: 100%|██████████| 2748/2748 [09:07<00:00,  5.02it/s, accuracy=0.995, cost=1.53] \n",
      "train minibatch loop:   0%|          | 0/20407 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 9405.782040596008\n",
      "epoch: 0, training loss: 4.537313, training acc: 0.979145, valid loss: 4.936748, valid acc: 0.981108\n",
      "\n",
      "[('Kuala', 'GPE'), ('Lumpur', 'GPE'), ('Sempena', 'OTHER'), ('sambutan', 'EVENT'), ('Aidilfitri', 'EVENT'), ('minggu', 'OTHER'), ('depan', 'OTHER'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'ORG'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'PERSON'), ('Fook', 'PERSON'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing-masing', 'OTHER'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ORG'), ('Keselamatan', 'ORG'), ('Jalan', 'ORG'), ('Raya', 'ORG'), ('(Jkjr)', 'ORG'), ('itu', 'OTHER'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('menasihati', 'OTHER'), ('mereka', 'OTHER'), ('supaya', 'OTHER'), ('berhenti', 'OTHER'), ('berehat', 'OTHER'), ('dan', 'OTHER'), ('tidur', 'OTHER'), ('sebentar', 'OTHER'), ('sekiranya', 'OTHER'), ('mengantuk', 'OTHER'), ('ketika', 'OTHER'), ('memandu', 'OTHER')]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 20407/20407 [2:27:33<00:00,  2.30it/s, accuracy=1, cost=0.0225]     \n",
      "test minibatch loop: 100%|██████████| 2748/2748 [09:04<00:00,  5.05it/s, accuracy=0.996, cost=1.54] \n",
      "train minibatch loop:   0%|          | 0/20407 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 9398.288618326187\n",
      "epoch: 1, training loss: 1.231279, training acc: 0.993845, valid loss: 5.231983, valid acc: 0.983887\n",
      "\n",
      "[('Kuala', 'GPE'), ('Lumpur', 'GPE'), ('Sempena', 'OTHER'), ('sambutan', 'EVENT'), ('Aidilfitri', 'EVENT'), ('minggu', 'DATE'), ('depan', 'DATE'), ('Perdana', 'OTHER'), ('Menteri', 'OTHER'), ('Tun', 'PERSON'), ('Dr', 'PERSON'), ('Mahathir', 'PERSON'), ('Mohamad', 'PERSON'), ('dan', 'OTHER'), ('Menteri', 'OTHER'), ('Pengangkutan', 'ORG'), ('Anthony', 'PERSON'), ('Loke', 'PERSON'), ('Siew', 'PERSON'), ('Fook', 'PERSON'), ('menitipkan', 'OTHER'), ('pesanan', 'OTHER'), ('khas', 'OTHER'), ('kepada', 'OTHER'), ('orang', 'OTHER'), ('ramai', 'OTHER'), ('yang', 'OTHER'), ('mahu', 'OTHER'), ('pulang', 'OTHER'), ('ke', 'OTHER'), ('kampung', 'OTHER'), ('halaman', 'OTHER'), ('masing-masing', 'OTHER'), ('Dalam', 'OTHER'), ('video', 'OTHER'), ('pendek', 'OTHER'), ('terbitan', 'OTHER'), ('Jabatan', 'ORG'), ('Keselamatan', 'ORG'), ('Jalan', 'ORG'), ('Raya', 'ORG'), ('(Jkjr)', 'ORG'), ('itu', 'OTHER'), ('Dr', 'OTHER'), ('Mahathir', 'PERSON'), ('menasihati', 'OTHER'), ('mereka', 'OTHER'), ('supaya', 'OTHER'), ('berhenti', 'OTHER'), ('berehat', 'OTHER'), ('dan', 'OTHER'), ('tidur', 'OTHER'), ('sebentar', 'OTHER'), ('sekiranya', 'OTHER'), ('mengantuk', 'OTHER'), ('ketika', 'OTHER'), ('memandu', 'OTHER')]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:   0%|          | 14/20407 [00:06<2:34:30,  2.20it/s, accuracy=0.997, cost=0.5]  \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-31-1f6d6c8506d4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     23\u001b[0m                 \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_y\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m                 \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msegment_ids\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_segments\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m                 \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minput_masks\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbatch_masks\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m             },\n\u001b[1;32m     27\u001b[0m         )\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m    954\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    955\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 956\u001b[0;31m                          run_metadata_ptr)\n\u001b[0m\u001b[1;32m    957\u001b[0m       \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    958\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1178\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1179\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1180\u001b[0;31m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m   1181\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1182\u001b[0m       \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m   1357\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1358\u001b[0m       return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1359\u001b[0;31m                            run_metadata)\n\u001b[0m\u001b[1;32m   1360\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1361\u001b[0m       \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m   1363\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1364\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1365\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1366\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1367\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m   1348\u001b[0m       \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1349\u001b[0m       return self._call_tf_sessionrun(options, feed_dict, fetch_list,\n\u001b[0;32m-> 1350\u001b[0;31m                                       target_list, run_metadata)\n\u001b[0m\u001b[1;32m   1351\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1352\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m   1441\u001b[0m     return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,\n\u001b[1;32m   1442\u001b[0m                                             \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1443\u001b[0;31m                                             run_metadata)\n\u001b[0m\u001b[1;32m   1444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1445\u001b[0m   \u001b[0;32mdef\u001b[0m \u001b[0m_call_tf_sessionprun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "for e in range(epoch):\n",
    "    lasttime = time.time()\n",
    "    train_acc, train_loss, test_acc, test_loss = [], [], [], []\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = train_X[i : index]\n",
    "        batch_y = train_Y[i : index]\n",
    "        batch_masks = train_masks[i : index]\n",
    "        batch_segments = train_segments[i : index]\n",
    "        batch_x = pad_sequences(batch_x, padding='post')\n",
    "        batch_y = pad_sequences(batch_y, padding='post')\n",
    "        batch_segments = pad_sequences(batch_segments, padding='post', value = 4)\n",
    "        batch_masks = pad_sequences(batch_masks, padding='post', value = 1)\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: batch_masks,\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss.append(cost)\n",
    "        train_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    pbar = tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'test minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = test_X[i : index]\n",
    "        batch_y = test_Y[i : index]\n",
    "        batch_masks = test_masks[i : index]\n",
    "        batch_segments = test_segments[i : index]\n",
    "        batch_x = pad_sequences(batch_x, padding='post')\n",
    "        batch_y = pad_sequences(batch_y, padding='post')\n",
    "        batch_segments = pad_sequences(batch_segments, padding='post', value = 4)\n",
    "        batch_masks = pad_sequences(batch_masks, padding='post', value = 1)\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.X: batch_x,\n",
    "                model.Y: batch_y,\n",
    "                model.segment_ids: batch_segments,\n",
    "                model.input_masks: batch_masks,\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        test_loss.append(cost)\n",
    "        test_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    train_loss = np.mean(train_loss)\n",
    "    train_acc = np.mean(train_acc)\n",
    "    test_loss = np.mean(test_loss)\n",
    "    test_acc = np.mean(test_acc)\n",
    "\n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (e, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    \n",
    "    predicted = sess.run(model.tags_seq,\n",
    "                feed_dict = {\n",
    "                    model.X: [parsed_sequence],\n",
    "                    model.segment_ids: [segment_sequence],\n",
    "                    model.input_masks: [mask_sequence],\n",
    "                },\n",
    "        )[0]\n",
    "    merged = merge_sentencepiece_tokens_tagging(xlnet_sequence, [idx2tag[d] for d in predicted])\n",
    "    print(list(zip(merged[0], merged[1])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'alxlnet-base-entities/model.ckpt'"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, 'alxlnet-base-entities/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "kwargs = dict(\n",
    "      is_training=False,\n",
    "      use_tpu=False,\n",
    "      use_bfloat16=False,\n",
    "      dropout=0.0,\n",
    "      dropatt=0.0,\n",
    "      init='normal',\n",
    "      init_range=0.1,\n",
    "      init_std=0.05,\n",
    "      clamp_len=-1)\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='alxlnet-base-2020-04-10/config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    }
   ],
   "source": [
    "dimension_output = len(tag2idx)\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-entities/model.ckpt\n"
     ]
    }
   ],
   "source": [
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.restore(sess, 'alxlnet-base-entities/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred2label(pred):\n",
    "    out = []\n",
    "    for pred_i in pred:\n",
    "        out_i = []\n",
    "        for p in pred_i:\n",
    "            out_i.append(idx2tag[p])\n",
    "        out.append(out_i)\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "validation minibatch loop: 100%|██████████| 2748/2748 [08:31<00:00,  5.37it/s]\n"
     ]
    }
   ],
   "source": [
    "real_Y, predict_Y = [], []\n",
    "\n",
    "pbar = tqdm(\n",
    "    range(0, len(test_X), batch_size), desc = 'validation minibatch loop'\n",
    ")\n",
    "for i in pbar:\n",
    "    index = min(i + batch_size, len(test_X))\n",
    "    batch_x = test_X[i : index]\n",
    "    batch_y = test_Y[i : index]\n",
    "    batch_masks = test_masks[i : index]\n",
    "    batch_segments = test_segments[i : index]\n",
    "    batch_x = pad_sequences(batch_x, padding='post')\n",
    "    batch_y = pad_sequences(batch_y, padding='post')\n",
    "    batch_segments = pad_sequences(batch_segments, padding='post', value = 4)\n",
    "    batch_masks = pad_sequences(batch_masks, padding='post', value = 1)\n",
    "    predicted = pred2label(sess.run(\n",
    "        model.tags_seq,\n",
    "        feed_dict = {\n",
    "            model.X: batch_x,\n",
    "            model.segment_ids: batch_segments,\n",
    "            model.input_masks: batch_masks,\n",
    "        },\n",
    "    ))\n",
    "    real = pred2label(batch_y)\n",
    "    predict_Y.extend(predicted)\n",
    "    real_Y.extend(real)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_real_Y = []\n",
    "for r in real_Y:\n",
    "    temp_real_Y.extend(r)\n",
    "    \n",
    "temp_predict_Y = []\n",
    "for r in predict_Y:\n",
    "    temp_predict_Y.extend(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "     ADDRESS    0.99949   0.99981   0.99965     93446\n",
      "    CARDINAL    0.92765   0.91402   0.92078     48255\n",
      "        DATE    0.95309   0.93386   0.94338    126548\n",
      "       EVENT    0.88426   0.93241   0.90770      5711\n",
      "         FAC    0.92367   0.92991   0.92678     27392\n",
      "         GPE    0.93880   0.95315   0.94592    101357\n",
      "    LANGUAGE    0.84296   0.90909   0.87478       803\n",
      "         LAW    0.95472   0.95091   0.95281     24834\n",
      "         LOC    0.92551   0.92953   0.92751     34538\n",
      "       MONEY    0.86719   0.87403   0.87060     30032\n",
      "        NORP    0.95470   0.89518   0.92398     57014\n",
      "     ORDINAL    0.86582   0.94721   0.90469      6213\n",
      "         ORG    0.95300   0.93138   0.94206    219533\n",
      "       OTHER    0.99080   0.99334   0.99206   3553350\n",
      "         PAD    0.99992   1.00000   0.99996   1292031\n",
      "     PERCENT    0.96856   0.96851   0.96853     21722\n",
      "      PERSON    0.94616   0.96716   0.95655    101981\n",
      "     PRODUCT    0.87820   0.79333   0.83361     11124\n",
      "    QUANTITY    0.94752   0.91872   0.93290     11614\n",
      "        TIME    0.90322   0.90949   0.90635      9502\n",
      " WORK_OF_ART    0.88971   0.79732   0.84098     13800\n",
      "           X    0.99883   0.99891   0.99887   1349384\n",
      "\n",
      "    accuracy                        0.98816   7140184\n",
      "   macro avg    0.93244   0.92942   0.93047   7140184\n",
      "weighted avg    0.98810   0.98816   0.98810   7140184\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import classification_report\n",
    "print(classification_report(temp_real_Y, temp_predict_Y, digits = 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Placeholder',\n",
       " 'Placeholder_1',\n",
       " 'Placeholder_2',\n",
       " 'Placeholder_3',\n",
       " 'model/transformer/r_w_bias',\n",
       " 'model/transformer/r_r_bias',\n",
       " 'model/transformer/word_embedding/lookup_table',\n",
       " 'model/transformer/word_embedding/lookup_table_2',\n",
       " 'model/transformer/r_s_bias',\n",
       " 'model/transformer/seg_embed',\n",
       " 'model/transformer/layer_shared/rel_attn/q/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/k/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/v/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/r/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/o/kernel',\n",
       " 'model/transformer/layer_shared/rel_attn/LayerNorm/gamma',\n",
       " 'model/transformer/layer_shared/ff/layer_1/kernel',\n",
       " 'model/transformer/layer_shared/ff/layer_1/bias',\n",
       " 'model/transformer/layer_shared/ff/layer_2/kernel',\n",
       " 'model/transformer/layer_shared/ff/layer_2/bias',\n",
       " 'model/transformer/layer_shared/ff/LayerNorm/gamma',\n",
       " 'dense/kernel',\n",
       " 'dense/bias',\n",
       " 'transitions',\n",
       " 'logits']"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "strings = ','.join(\n",
    "    [\n",
    "        n.name\n",
    "        for n in tf.get_default_graph().as_graph_def().node\n",
    "        if ('Variable' in n.op\n",
    "        or 'Placeholder' in n.name\n",
    "        or 'logits' in n.name\n",
    "        or 'alphas' in n.name\n",
    "        or 'self/Softmax' in n.name)\n",
    "        and 'Adam' not in n.name\n",
    "        and 'beta' not in n.name\n",
    "        and 'global_step' not in n.name\n",
    "    ]\n",
    ")\n",
    "strings.split(',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def freeze_graph(model_dir, output_node_names):\n",
    "\n",
    "    if not tf.gfile.Exists(model_dir):\n",
    "        raise AssertionError(\n",
    "            \"Export directory doesn't exists. Please specify an export \"\n",
    "            'directory: %s' % model_dir\n",
    "        )\n",
    "\n",
    "    checkpoint = tf.train.get_checkpoint_state(model_dir)\n",
    "    input_checkpoint = checkpoint.model_checkpoint_path\n",
    "\n",
    "    absolute_model_dir = '/'.join(input_checkpoint.split('/')[:-1])\n",
    "    output_graph = absolute_model_dir + '/frozen_model.pb'\n",
    "    clear_devices = True\n",
    "    with tf.Session(graph = tf.Graph()) as sess:\n",
    "        saver = tf.train.import_meta_graph(\n",
    "            input_checkpoint + '.meta', clear_devices = clear_devices\n",
    "        )\n",
    "        saver.restore(sess, input_checkpoint)\n",
    "        output_graph_def = tf.graph_util.convert_variables_to_constants(\n",
    "            sess,\n",
    "            tf.get_default_graph().as_graph_def(),\n",
    "            output_node_names.split(','),\n",
    "        )\n",
    "        with tf.gfile.GFile(output_graph, 'wb') as f:\n",
    "            f.write(output_graph_def.SerializeToString())\n",
    "        print('%d ops in the final graph.' % len(output_graph_def.node))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from alxlnet-base-entities/model.ckpt\n",
      "WARNING:tensorflow:From <ipython-input-44-9a7215a4e58a>:23: convert_variables_to_constants (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.convert_variables_to_constants`\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/framework/graph_util_impl.py:277: extract_sub_graph (from tensorflow.python.framework.graph_util_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.compat.v1.graph_util.extract_sub_graph`\n",
      "INFO:tensorflow:Froze 22 variables.\n",
      "INFO:tensorflow:Converted 22 variables to const ops.\n",
      "7687 ops in the final graph.\n"
     ]
    }
   ],
   "source": [
    "freeze_graph('alxlnet-base-entities', strings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Kyrgios', 'DATE'),\n",
       " ('25', 'DATE'),\n",
       " ('membuat', 'OTHER'),\n",
       " ('pesanan', 'OTHER'),\n",
       " ('itu', 'OTHER'),\n",
       " ('kerana', 'OTHER'),\n",
       " ('menyedaari', 'OTHER'),\n",
       " ('pelbagai', 'OTHER'),\n",
       " ('kesukaran', 'OTHER'),\n",
       " ('menimpa', 'OTHER'),\n",
       " ('rakyat', 'OTHER'),\n",
       " ('Australia', 'GPE'),\n",
       " ('ekoran', 'OTHER'),\n",
       " ('perintah', 'OTHER'),\n",
       " ('kawalan', 'OTHER'),\n",
       " ('pergerakan', 'OTHER'),\n",
       " ('yang', 'OTHER'),\n",
       " ('diumumkan', 'OTHER'),\n",
       " ('Mac', 'DATE'),\n",
       " ('lalu', 'DATE'),\n",
       " ('bagi', 'OTHER'),\n",
       " ('memerangi', 'OTHER'),\n",
       " ('wabak', 'OTHER'),\n",
       " ('Covid-19', 'OTHER'),\n",
       " ('di', 'OTHER'),\n",
       " ('negara', 'OTHER'),\n",
       " ('berkenaan', 'OTHER'),\n",
       " ('Pemain', 'OTHER'),\n",
       " ('tenis', 'OTHER'),\n",
       " ('ranking', 'OTHER'),\n",
       " ('ke-40', 'ORDINAL'),\n",
       " ('dunia', 'OTHER'),\n",
       " ('yang', 'OTHER'),\n",
       " ('dilahirkan', 'OTHER'),\n",
       " ('di', 'OTHER'),\n",
       " ('Canberra', 'GPE'),\n",
       " ('itu', 'OTHER'),\n",
       " ('meminta', 'OTHER'),\n",
       " ('pengikut', 'OTHER'),\n",
       " ('dan', 'OTHER'),\n",
       " ('penyokongnya', 'OTHER'),\n",
       " ('agar', 'OTHER'),\n",
       " ('jangan', 'OTHER'),\n",
       " ('tidur', 'OTHER'),\n",
       " ('dalam', 'OTHER'),\n",
       " ('keadaan', 'OTHER'),\n",
       " ('perut', 'OTHER'),\n",
       " ('kosong', 'OTHER'),\n",
       " ('dalam', 'OTHER'),\n",
       " ('hantaran', 'OTHER'),\n",
       " ('Instagram', 'ORG'),\n",
       " ('yang', 'OTHER'),\n",
       " ('meraih', 'OTHER'),\n",
       " ('lebih', 'CARDINAL'),\n",
       " ('92', 'CARDINAL'),\n",
       " ('000', 'CARDINAL'),\n",
       " ('tanda', 'OTHER'),\n",
       " ('suka', 'OTHER')]"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string = 'Kyrgios, 25, membuat pesanan itu kerana menyedaari pelbagai kesukaran menimpa rakyat Australia ekoran perintah kawalan pergerakan yang diumumkan Mac lalu bagi memerangi wabak COVID-19 di negara berkenaan. Pemain tenis ranking ke-40 dunia yang dilahirkan di Canberra itu meminta pengikut dan penyokongnya agar jangan tidur dalam keadaan perut kosong dalam hantaran Instagram yang meraih lebih 92,000 tanda suka.'\n",
    "\n",
    "sequence = entities_textcleaning(string)[1]\n",
    "parsed_sequence, segment_sequence, mask_sequence, xlnet_sequence = parse_X(sequence)\n",
    "predicted = sess.run(model.tags_seq,\n",
    "                feed_dict = {\n",
    "                    model.X: [parsed_sequence],\n",
    "                    model.segment_ids: [segment_sequence],\n",
    "                    model.input_masks: [mask_sequence],\n",
    "                },\n",
    "        )[0]\n",
    "merged = merge_sentencepiece_tokens_tagging(xlnet_sequence, [idx2tag[d] for d in predicted])\n",
    "list(zip(merged[0], merged[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.tools.graph_transforms import TransformGraph\n",
    "from glob import glob\n",
    "tf.set_random_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms = ['add_default_attributes',\n",
    "             'remove_nodes(op=Identity, op=CheckNumerics, op=Dropout)',\n",
    "             'fold_batch_norms',\n",
    "             'fold_old_batch_norms',\n",
    "             'quantize_weights(fallback_min=-10, fallback_max=10)',\n",
    "             'strip_unused_nodes',\n",
    "             'sort_by_execution_order']\n",
    "\n",
    "pb = 'alxlnet-base-entities/frozen_model.pb'\n",
    "input_graph_def = tf.GraphDef()\n",
    "with tf.gfile.FastGFile(pb, 'rb') as f:\n",
    "    input_graph_def.ParseFromString(f.read())\n",
    "\n",
    "inputs = ['Placeholder', 'Placeholder_1', 'Placeholder_2']\n",
    "outputs = ['transpose_3']\n",
    "\n",
    "print(pb, inputs)\n",
    "\n",
    "transformed_graph_def = TransformGraph(input_graph_def, \n",
    "                                       inputs,\n",
    "                                       ['logits'] + outputs, transforms)\n",
    "\n",
    "with tf.gfile.GFile(f'{pb}.quantized', 'wb') as f:\n",
    "    f.write(transformed_graph_def.SerializeToString())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
